import torch
import torchvision

# Minimal TripletMarginLoss demo: draw one random (anchor, positive, negative)
# triplet of 2-D points, evaluate the loss, print it, and backpropagate.
#
# Only `anchors` is created with requires_grad=True, so after backward()
# `anchors.grad` holds d(loss)/d(anchors); `positive` and `negative` are
# treated as constants and receive no gradient.

# The three randn calls keep their original order so the RNG draws match.
anchors = torch.randn((1, 2), requires_grad=True)
positive = torch.randn((1, 2))
negative = torch.randn((1, 2))

# These keyword values are TripletMarginLoss's documented defaults, written out
# for readability. reduction="mean" yields a 0-d tensor, which is what allows
# calling backward() without an explicit gradient argument.
loss = torch.nn.TripletMarginLoss(margin=1.0, p=2.0, reduction="mean")

out_loss = loss(anchors, positive, negative)
print(out_loss)
out_loss.backward()


